use crate::common::{
    symcodec::SymbolCodec, 
    symbol::SymbolWithin8Bits,
    bitio::{BitsRead, BitsWrite}
};
use core::slice;
use std::io::{Error as IOError, Read, Write, Seek, SeekFrom};
use ahash::AHashMap;

/// LZW (Lempel–Ziv–Welch) codec with variable-width codes of up to 16 bits.
///
/// The dictionary is seeded either with every possible `n`-bit symbol or with a caller-chosen
/// subset of them, and the code width grows one bit at a time as the dictionary fills up.
///
/// References:
/// - [Wikipedia](https://en.wikipedia.org/wiki/Lempel%E2%80%93Ziv%E2%80%93Welch)
/// - [Welch 84](http://www.cs.duke.edu/courses/spring03/cps296.5/papers/welch_1984_technique_for.pdf)
pub struct LZWCodec<S> {
    /// Width of one input symbol in bits.
    symbol_bits: u8,
    /// Symbols seeding the dictionary, or `None` when every `symbol_bits`-bit value is a seed.
    /// A selected symbol's initial code is its index in this slice.
    opt_selected_symbols: Option<Box<[S]>>,
    /// First code not taken by a seed symbol; new dictionary entries are numbered from here.
    initial_next_code: u16,
    /// Code width in bits at the start of a stream.
    initial_codelen: u8,
    /// Upper bound on the code width; the dictionary stops growing once every code of this
    /// width is in use.
    maximum_codelen: u8,
}

impl<S: SymbolWithin8Bits> LZWCodec<S> {
    /// Builds a codec whose initial dictionary contains every `n`-bit symbol value, where
    /// `n = S::bits_needed()`.
    ///
    /// Seed codes are `0..2^n`; codes start out `n + 1` bits wide and may grow up to
    /// `maximum_codelen` bits.
    ///
    /// # Panics
    /// Panics if `maximum_codelen <= S::bits_needed()` or `maximum_codelen > 16`.
    pub fn build_for_full_nbits_symbol(maximum_codelen: u8) -> Self {
        let symbol_bits = S::bits_needed();
        if maximum_codelen <= symbol_bits || maximum_codelen > 16 {
            panic!("maximum codelen should be between the number of bits all symbols needed and 16!")
        }
        Self {
            symbol_bits,
            maximum_codelen,
            initial_codelen: symbol_bits + 1,
            initial_next_code: 1 << symbol_bits,
            opt_selected_symbols: None,
        }
    }

    /// Builds a codec whose initial dictionary contains only the `selected` symbols.
    ///
    /// The initial code of each selected symbol is its index in `selected`, and the selection is
    /// stored in every encoded stream so the decoder can rebuild it. Encoding a symbol that is
    /// not in the selection fails.
    ///
    /// # Panics
    /// Panics if `maximum_codelen <= S::bits_needed()` or `maximum_codelen > 16`, if
    /// `selected` has `2^S::bits_needed()` or more entries (use
    /// [`LZWCodec::build_for_full_nbits_symbol`] instead), or if `selected` is empty.
    pub fn build_for_selected_nbits_symbol(selected: Vec<S>, maximum_codelen: u8) -> Self {
        let symbol_bits = S::bits_needed();
        if maximum_codelen <= symbol_bits || maximum_codelen > 16 {
            panic!("maximum codelen should be between the number of bits all symbols needed and 16!")
        }
        let n_selected = selected.len();
        if n_selected >= (1 << symbol_bits) {
            panic!("the number of selected symbols >= the number of possible symbols!")
        }
        if n_selected == 0 {
            // An empty selection could never encode a single symbol.
            panic!("at least one symbol must be selected!")
        }
        // Bit length of n_selected (floor(log2(n)) + 1 for n >= 1), computed exactly with integer
        // arithmetic: wide enough for every seed code 0..n and for the first new code n.
        let initial_codelen = (usize::BITS - n_selected.leading_zeros()) as u8;
        Self {
            symbol_bits,
            maximum_codelen,
            initial_codelen,
            // n_selected < 2^symbol_bits, checked above, so it fits in u16.
            initial_next_code: n_selected as u16,
            opt_selected_symbols: Some(selected.into_boxed_slice()),
        }
    }
}

impl<S: SymbolWithin8Bits> SymbolCodec for LZWCodec<S> {
    /// Encodes up to `symbol_read_limit` symbols from `i_byte_reader` into `o_byte_writer`.
    ///
    /// Encoded stream layout (offsets relative to the writer position on entry):
    /// - byte 0: symbol length in bits
    /// - byte 1: initial code length
    /// - byte 2: maximum code length
    /// - bytes 3..5: initial next code (u16, little endian)
    /// - bytes 5..7: constructed table size (u16, little endian)
    /// - bytes 7..11: number of symbols encoded (u32, little endian)
    /// - selected-symbol codecs only: one byte holding the number `k` of selected symbols,
    ///   followed by the `k` symbol values, one byte each
    /// - remaining bytes: the payload codes, written through the bit writer
    ///
    /// Returns the number of symbols encoded. When `symbol_read_limit` is 0 nothing is written.
    /// On success the writer is left positioned at the end of the encoded stream.
    ///
    /// # Errors
    /// Propagates I/O errors. Returns an `InvalidInput` error if a selected-symbol codec reads a
    /// symbol that is not among its selected symbols.
    fn encode<R, W>(
        &self,
        i_byte_reader: &mut R,
        o_byte_writer: &mut W,
        symbol_read_limit: u32,
    ) -> Result<u32, IOError>
    where
        R: Read,
        W: Write + Seek,
    {
        if symbol_read_limit == 0 {
            return Ok(0);
        }

        // The header is patched after the payload is written, so remember where it starts
        // instead of assuming the writer was at offset 0.
        let header_start = o_byte_writer.stream_position()?;

        o_byte_writer.write_all(&[self.symbol_bits, self.initial_codelen, self.maximum_codelen])?;
        o_byte_writer.write_all(&self.initial_next_code.to_le_bytes())?;
        // Placeholders for the table size (2 bytes) and the symbol count (4 bytes).
        o_byte_writer.write_all(&[0u8; 6])?;

        // Maps a selected symbol value to its initial code, which is its index in the selection.
        let selected_symbol_codetable = match self.opt_selected_symbols {
            Some(ref symbols) => {
                // The constructor guarantees symbols.len() < 2^symbol_bits, so for symbols of at
                // most 8 bits the count fits in one byte.
                let mut table = AHashMap::with_capacity(symbols.len());
                let mut section = Vec::with_capacity(symbols.len() + 1);
                section.push(symbols.len() as u8);
                for (code, s) in symbols.iter().enumerate() {
                    let value = s.val_in_u8();
                    table.insert(value, code as u16);
                    section.push(value);
                }
                o_byte_writer.write_all(&section)?;
                Some(table)
            }
            None => None,
        };

        // Initial (single-symbol) code of a raw symbol value.
        let symbol_code = |sym: u8| -> Result<u16, IOError> {
            match selected_symbol_codetable.as_ref() {
                Some(t) => t.get(&sym).copied().ok_or_else(|| {
                    IOError::new(
                        std::io::ErrorKind::InvalidInput,
                        "Partial n-bit symbol codec encounters unrecorded symbol!",
                    )
                }),
                None => Ok(sym as u16),
            }
        };

        // Maps (code of prefix string, next raw symbol) to the code of the extended string.
        let mut constructed_table: AHashMap<(u16, u8), u16> = AHashMap::new();
        let mut omega: Option<u16> = None;
        let mut b = 0u8;
        let mut growable = true;
        let mut codelen = self.initial_codelen;
        let mut next_code = self.initial_next_code as u32;
        let mut passed = 0u32;

        // Scope the bit reader/writer so their borrows end before the header is patched.
        {
            let mut bitsreader = BitsRead::create_with(i_byte_reader);
            let mut bitswriter = BitsWrite::create_with(&mut *o_byte_writer);

            while passed < symbol_read_limit
                && bitsreader.read_bits_8_8_unchecked(&mut b, self.symbol_bits)?
            {
                match omega {
                    None => omega = Some(symbol_code(b)?),
                    Some(omega_code) => {
                        let omega_b = (omega_code, b);
                        if let Some(&omega_b_code) = constructed_table.get(&omega_b) {
                            // omega+b is already known: keep extending the current match.
                            omega = Some(omega_b_code);
                        } else {
                            // Emit the longest known match, then register omega+b.
                            bitswriter.write_bits_32_8_unchecked(omega_code as u32, codelen)?;
                            if growable {
                                let need_extend = next_code == (1 << codelen);
                                if need_extend && codelen == self.maximum_codelen {
                                    // Every code of the maximum width is taken: freeze the table.
                                    growable = false;
                                } else {
                                    constructed_table.insert(omega_b, next_code as u16);
                                    if need_extend {
                                        codelen += 1;
                                    }
                                    next_code += 1;
                                }
                            }
                            omega = Some(symbol_code(b)?);
                        }
                    }
                }
                passed += 1;
            }

            // Emit the pending match. An empty input leaves nothing to emit or flush.
            if let Some(omega_code) = omega {
                bitswriter.write_bits_32_8_unchecked(omega_code as u32, codelen)?;
                bitswriter.flush_buffered_bits()?;
            }
        }

        // Patch the placeholders, then leave the writer at the end of the encoded stream.
        let stream_end = o_byte_writer.stream_position()?;
        o_byte_writer.seek(SeekFrom::Start(header_start + 5))?;
        o_byte_writer.write_all(&(constructed_table.len() as u16).to_le_bytes())?;
        o_byte_writer.write_all(&passed.to_le_bytes())?;
        o_byte_writer.seek(SeekFrom::Start(stream_end))?;

        Ok(passed)
    }

    /// Decodes one stream produced by `encode` from `i_byte_reader` into `o_byte_writer`.
    ///
    /// Decoded symbols are written as whole bytes when the symbol length is 8 bits, otherwise
    /// `symbol_bits` bits at a time through the bit writer. Returns the number of symbols written.
    /// If the code stream ends before the recorded symbol count is reached, decoding stops
    /// without error and the shorter count is returned.
    ///
    /// # Errors
    /// Propagates I/O errors, including `UnexpectedEof` for a truncated header. Returns an
    /// `InvalidData` error when the header is inconsistent or a code refers to an unknown entry.
    fn decode<R, W>(i_byte_reader: &mut R, o_byte_writer: &mut W) -> Result<u32, IOError>
    where
        R: Read + Seek,
        W: Write,
    {
        let mut header = [0u8; 11];
        i_byte_reader.read_exact(&mut header)?;
        let symbol_bits = header[0];
        let initial_codelen = header[1];
        let maximum_codelen = header[2];
        let initial_next_code = u16::from_le_bytes([header[3], header[4]]);
        let constructed_table_len = u16::from_le_bytes([header[5], header[6]]);
        let symbol_counts = u32::from_le_bytes([header[7], header[8], header[9], header[10]]);

        // Reject headers that would make the shifts, casts and lookups below overflow or panic.
        if symbol_bits > 8
            || maximum_codelen > 16
            || initial_codelen == 0
            || initial_codelen > maximum_codelen
            || u32::from(initial_next_code) > (1u32 << symbol_bits)
        {
            return Err(corrupted_stream_error("LZW header is corrupted"));
        }

        // Fewer seed codes than possible symbols means the encoder stored a selection.
        let selected_symbols: Option<Vec<u8>> =
            if u32::from(initial_next_code) < (1u32 << symbol_bits) {
                let mut count = [0u8];
                i_byte_reader.read_exact(&mut count)?;
                // The encoder writes exactly `initial_next_code` selected symbols.
                if u16::from(count[0]) != initial_next_code {
                    return Err(corrupted_stream_error(
                        "LZW selected-symbol count does not match the header",
                    ));
                }
                let mut symbols = vec![0u8; count[0] as usize];
                i_byte_reader.read_exact(&mut symbols)?;
                Some(symbols)
            } else {
                None
            };

        if symbol_counts == 0 {
            return Ok(0);
        }

        // Raw symbol value for a seed code (`code < initial_next_code`).
        let initial_symbol = |code: u16| -> u8 {
            match selected_symbols.as_ref() {
                // In bounds: symbols.len() == initial_next_code, checked above.
                Some(symbols) => symbols[code as usize],
                // initial_next_code == 2^symbol_bits <= 256 here, so the code fits in a byte.
                None => code as u8,
            }
        };

        let mut bitsreader = BitsRead::create_with(i_byte_reader);
        let mut bitswriter = BitsWrite::create_with(o_byte_writer);

        // Entry i holds the string for code `initial_next_code + i`.
        let mut constructed_table: Vec<Vec<u8>> = Vec::with_capacity(constructed_table_len as usize);
        let mut gamma_code = 0u16;
        let mut delta_code = 0u16;
        let mut next_code = initial_next_code as u32;
        let mut codelen = initial_codelen;
        let mut passed = 0u32;

        if !bitsreader.read_bits_16_128_unchecked(&mut gamma_code, codelen)? {
            return Ok(0);
        }

        loop {
            // Output gamma. Then find the postfix symbol X that made the encoder emit gamma, and
            // insert 'gammaX' into the table as the encoder did. X is the first symbol of the
            // next string (delta). If delta's code is not in the table yet (the KwKwK case),
            // delta starts with gamma, so X is gamma's first symbol.

            // Get and output the string gamma.
            let mut gamma = if gamma_code < initial_next_code {
                let sym = initial_symbol(gamma_code);
                if symbol_bits == 8 {
                    bitswriter.write_bytes(slice::from_ref(&sym))?;
                } else {
                    bitswriter.write_bits_32_8_unchecked(sym as u32, symbol_bits)?;
                }
                vec![sym]
            } else {
                let s = constructed_table
                    .get((gamma_code - initial_next_code) as usize)
                    .ok_or_else(|| corrupted_stream_error("LZW code refers to an unknown entry"))?;
                if symbol_bits == 8 {
                    bitswriter.write_bytes(s)?;
                } else {
                    for sym in s {
                        bitswriter.write_bits_32_8_unchecked(*sym as u32, symbol_bits)?;
                    }
                }
                s.clone()
            };
            passed = passed.saturating_add(gamma.len() as u32);

            // Have we just output the last symbol?
            if passed >= symbol_counts {
                break;
            }

            // Read delta's code. The encoder widened its codes right after inserting entry
            // `next_code` if that entry no longer fit the current width.
            let delta_codelen = if next_code == (1 << codelen) && codelen < maximum_codelen {
                codelen + 1
            } else {
                codelen
            };
            if !bitsreader.read_bits_16_128_unchecked(&mut delta_code, delta_codelen)? {
                break;
            }
            codelen = delta_codelen;
            next_code = next_code.saturating_add(1);

            // The encoder stopped growing its table once it held this many entries.
            if constructed_table.len() >= constructed_table_len as usize {
                gamma_code = delta_code;
                continue;
            }

            // Get the postfix symbol X.
            let postfix_symbol_x = if delta_code < initial_next_code {
                initial_symbol(delta_code)
            } else {
                let delta_index = (delta_code - initial_next_code) as usize;
                match constructed_table.get(delta_index) {
                    Some(delta) => delta[0],
                    // KwKwK: delta is exactly the entry about to be added, which starts with gamma.
                    None if delta_index == constructed_table.len() => gamma[0],
                    None => {
                        return Err(corrupted_stream_error("LZW code refers to an unknown entry"))
                    }
                }
            };

            // Add 'gammaX' to the table and move on to delta.
            gamma.push(postfix_symbol_x);
            constructed_table.push(gamma);
            gamma_code = delta_code;
        }

        // The bit writer is local to this call, so any buffered bits must be flushed here or
        // they are lost.
        bitswriter.flush_buffered_bits()?;
        Ok(passed)
    }
}

/// Builds the error returned when an encoded LZW stream is malformed.
fn corrupted_stream_error(msg: &'static str) -> IOError {
    IOError::new(std::io::ErrorKind::InvalidData, msg)
}
